{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "48f053c3",
   "metadata": {},
   "source": [
    "# Neural SDEs Cookbook\n",
    "In this notebook, we explore the Neural SDE module in torchdyn. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "579b6a9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dd8af921",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "82ab358e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchdyn.utils import plot_3D_dataset\n",
    "from torchdyn.datasets import ToyDataset\n",
    "from torch.utils.data import TensorDataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "0588f5b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the first CUDA device; fall back to the CPU when CUDA is unavailable.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "# When True, the training cell builds a Trainer limited to a single epoch.\n",
    "dry_run = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "6ead7f83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate 512 noisy samples of the 'moons' toy dataset and move them to `device`.\n",
    "d = ToyDataset()\n",
    "X, yn = d.generate(n_samples=512, dataset_type='moons', noise=.4)\n",
    "X_train = torch.Tensor(X).to(device)\n",
    "# Labels are cast to integer (long) type; they are the targets of CrossEntropyLoss\n",
    "# in the Learner's training step.\n",
    "y_train = torch.LongTensor(yn.long()).to(device)\n",
    "train = TensorDataset(X_train, y_train)\n",
    "# Full-batch, unshuffled loading: each step sees all len(X) samples in the same order.\n",
    "trainloader = DataLoader(train, batch_size=len(X), shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "51f0f680",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "class Learner(pl.LightningModule):\n",
    "    \"\"\"Lightning wrapper that trains `model` as a classifier on its final state.\n",
    "\n",
    "    Args:\n",
    "        t_span: time grid handed to `model` on every training step.\n",
    "        model: module called as `model(x, t_span)` and expected to return a\n",
    "            `(t_eval, trajectory)` pair.\n",
    "    \"\"\"\n",
    "    def __init__(self, t_span:torch.Tensor, model:nn.Module):\n",
    "        super().__init__()\n",
    "        self.model, self.t_span = model, t_span\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"Call the wrapped model on `x` with its default arguments.\"\"\"\n",
    "        return self.model(x)\n",
    "    \n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Solve over `self.t_span` and return the cross-entropy loss of the last state.\"\"\"\n",
    "        x, y = batch\n",
    "        # Use the time grid stored by the constructor, not a notebook-level global:\n",
    "        # the learner must not depend on which other cells happened to run first.\n",
    "        t_eval, y_hat = self.model(x, self.t_span)\n",
    "        y_hat = y_hat[-1] # select last point of solution trajectory\n",
    "        loss = nn.CrossEntropyLoss()(y_hat, y)\n",
    "        return {'loss': loss}\n",
    "    \n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Optimize the wrapped model's parameters with Adam (lr=0.01).\"\"\"\n",
    "        return torch.optim.Adam(self.model.parameters(), lr=0.01)\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        \"\"\"Return the notebook-level `trainloader` defined in an earlier cell.\"\"\"\n",
    "        return trainloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "e1d3f3e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchdyn.core import NeuralSDE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "561a12a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchsde\n",
    "\n",
    "# Time grid: 100 evenly spaced points on [0, 0.1].\n",
    "t_span = torch.linspace(0, 0.1, 100).to(device)\n",
    "# The noise shape equals the shape of the full dataset X. This matches each batch\n",
    "# only because the DataLoader is full-batch (batch_size=len(X)).\n",
    "size = X.shape\n",
    "\n",
    "# Brownian motion spanning the same interval as t_span, created on `device`.\n",
    "bm = torchsde.BrownianInterval(\n",
    "    t0=t_span[0],\n",
    "    t1=t_span[-1],\n",
    "    size=size,\n",
    "    device=device,\n",
    "    levy_area_approximation='space-time'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "98f8b93b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two independent 2 -> 64 -> 2 MLPs with tanh activations, passed to NeuralSDE as\n",
    "# its first and second arguments (presumably drift `f` and diffusion `g` --\n",
    "# TODO confirm against the torchdyn NeuralSDE signature).\n",
    "f = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 2))\n",
    "g = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 2))\n",
    "\n",
    "\n",
    "# `t_span` and `bm` are the time grid and Brownian motion built in the cell above.\n",
    "model = NeuralSDE(f, \n",
    "                  g,\n",
    "                  solver='euler',\n",
    "                  noise_type='diagonal',\n",
    "                  sde_type='ito',\n",
    "                  sensitivity='autograd',\n",
    "                  t_span=t_span,\n",
    "                  bm=bm\n",
    "                 ).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "f9cdcaf5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name  | Type      | Params\n",
      "------------------------------------\n",
      "0 | model | NeuralSDE | 644   \n",
      "------------------------------------\n",
      "644       Trainable params\n",
      "0         Non-trainable params\n",
      "644       Total params\n",
      "0.003     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f58285a148e743429dbc8f1e5472fd22",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn = Learner(t_span, model)\n",
    "# Request a GPU only when one exists, mirroring the `device` fallback, so the\n",
    "# notebook also runs on CPU-only machines instead of failing in pl.Trainer.\n",
    "n_gpus = 1 if torch.cuda.is_available() else 0\n",
    "# dry_run selects a single-epoch run; otherwise train for 200 to 300 epochs.\n",
    "if dry_run: trainer = pl.Trainer(min_epochs=1, max_epochs=1, gpus=n_gpus)\n",
    "else: trainer = pl.Trainer(min_epochs=200, max_epochs=300, gpus=n_gpus)\n",
    "trainer.fit(learn)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_node)",
   "language": "python",
   "name": "conda_node"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
